Skip to content

Fix DP coordinator ZMQ port TOCTOU#37452

Merged
tlrmchlsmth merged 3 commits into
vllm-project:mainfrom
itayalroy:zmq_toctou
Mar 20, 2026
Merged

Fix DP coordinator ZMQ port TOCTOU#37452
tlrmchlsmth merged 3 commits into
vllm-project:mainfrom
itayalroy:zmq_toctou

Conversation

@itayalroy

@itayalroy itayalroy commented Mar 18, 2026

Copy link
Copy Markdown
Contributor

Previously the parent selected the DP coordinator's TCP ZMQ ports with
get_open_port() before the coordinator actually bound them, leaving a
window where another socket could claim the ports.

Fix this by letting the coordinator bind first and report the bound ZMQ
addresses back to the parent via pipe.

Signed-off-by: Itay Alroy <ialroy@nvidia.com>
@itayalroy itayalroy requested a review from njhill as a code owner March 18, 2026 16:02
@mergify mergify Bot added the v1 label Mar 18, 2026

@gemini-code-assist gemini-code-assist Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request addresses a TOCTOU (Time-of-Check Time-of-Use) vulnerability in the DP coordinator's ZMQ port selection. Previously, the parent process selected the ports before the coordinator bound them, creating a window for other processes to claim those ports. This is fixed by having the coordinator bind the ports first and then report the bound addresses back to the parent via a pipe. The changes involve modifications to network_utils.py and coordinator.py to implement this new mechanism.

Comment on lines +60 to +73
ready = multiprocessing.connection.wait(
[zmq_addr_pipe, self.proc.sentinel], timeout=30
)
if not ready:
raise RuntimeError(
"DP Coordinator process failed to report ZMQ addresses "
"during startup."
)
try:
return zmq_addr_pipe.recv()
except EOFError:
raise RuntimeError(
"DP Coordinator process failed during startup."
) from None

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

The _wait_for_zmq_addrs method includes a timeout of 30 seconds. If the DP Coordinator process fails to report ZMQ addresses within this time, a RuntimeError is raised. However, there's no mechanism to handle or retry this failure. Consider adding a retry mechanism or a more robust error handling strategy to improve the resilience of the system. This is a critical issue because a failure here will prevent the engine from starting up.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is fatal IMO, if the DP Coordinator cannot report ZMQ addresses within 30 seconds it is reasonable to fail

@zch42 zch42 May 6, 2026

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@itayalroy @tlrmchlsmth can we make the timeout configurable? The current 30s limit can be too short, for example when spawn is forced, the child process will re-import many modules

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same issue. from vllm.v1.engine import coordinator takes 70+ seconds to import.

Comment on lines +120 to +125
child_zmq_addr_pipe.close()
(
front_publish_address,
back_output_address,
back_publish_address,
) = self._wait_for_zmq_addrs(parent_zmq_addr_pipe)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

After starting the coordinator process, the parent process retrieves the bound ZMQ addresses using self._wait_for_zmq_addrs. However, if self._wait_for_zmq_addrs fails, the addresses used to initialize self.stats_publish_address and self.coord_in_address will be the original, unbound addresses. This could lead to the parent process attempting to communicate with the coordinator on the wrong ports. This is a critical issue because it can lead to communication failures between the parent and coordinator processes.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No. If _wait_for_zmq_addrs() fails, we raise an exception, so we never proceed using wrong ports.

Comment on lines +93 to +98
def bind_address(local_only: bool) -> str:
return (
get_engine_client_zmq_addr(local_only=True, host=host)
if local_only
else get_tcp_uri(host, 0)
)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The bind_address function uses get_engine_client_zmq_addr when local_only is true, which returns an IPC path. However, when local_only is false, it uses get_tcp_uri with port 0, which requests the OS to assign a port. This inconsistency in address types (IPC vs. TCP) could lead to unexpected behavior or configuration issues. Ensure that the address type is consistent based on the deployment environment or configuration. This is a high severity issue because it can lead to connectivity problems.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The inconsistency in address types (IPC/TCP) already exists, the only change is that with TCP we now let the OS assign the port on bind time instead of binding to a pre-chosen port that might be already taken

Comment thread vllm/v1/engine/coordinator.py Outdated

@tlrmchlsmth tlrmchlsmth left a comment

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good to me, @njhill if you can take a look, it'd be good to get another pair of eyes.

@itayalroy do you think there's a reasonable way to unit test this?

Comment thread vllm/v1/engine/coordinator.py
Comment thread vllm/v1/engine/coordinator.py Outdated
@tlrmchlsmth tlrmchlsmth self-assigned this Mar 19, 2026
Signed-off-by: Itay Alroy <ialroy@nvidia.com>
Signed-off-by: Itay Alroy <ialroy@nvidia.com>
@tlrmchlsmth tlrmchlsmth added the ready ONLY add when PR is ready to merge/full CI is needed label Mar 19, 2026
@tlrmchlsmth tlrmchlsmth enabled auto-merge (squash) March 19, 2026 22:15
@tlrmchlsmth tlrmchlsmth merged commit ca1ac1a into vllm-project:main Mar 20, 2026
51 of 52 checks passed
chooper26 pushed a commit to vLLM-HUST/vllm-hust that referenced this pull request Mar 21, 2026
Signed-off-by: Itay Alroy <ialroy@nvidia.com>
SouthWest7 pushed a commit to SouthWest7/vllm that referenced this pull request Mar 27, 2026
Signed-off-by: Itay Alroy <ialroy@nvidia.com>
khairulkabir1661 pushed a commit to khairulkabir1661/vllm that referenced this pull request Mar 27, 2026
Signed-off-by: Itay Alroy <ialroy@nvidia.com>
nithinvc pushed a commit to nithinvc/vllm that referenced this pull request Mar 27, 2026
Signed-off-by: Itay Alroy <ialroy@nvidia.com>

Signed-off-by: Nithin Chalapathi <nithin.ch10@gmail.com>
JiantaoXu pushed a commit to JiantaoXu/vllm that referenced this pull request Mar 28, 2026
Signed-off-by: Itay Alroy <ialroy@nvidia.com>
mtparet pushed a commit to blackfuel-ai/vllm that referenced this pull request Apr 9, 2026
Signed-off-by: Itay Alroy <ialroy@nvidia.com>
mystous pushed a commit to mystous/vllm_hybrid that referenced this pull request May 10, 2026
Signed-off-by: Itay Alroy <ialroy@nvidia.com>
my-other-github-account pushed a commit to my-other-github-account/vllm that referenced this pull request May 15, 2026
Signed-off-by: Itay Alroy <ialroy@nvidia.com>
my-other-github-account pushed a commit to my-other-github-account/vllm that referenced this pull request May 15, 2026
Signed-off-by: Itay Alroy <ialroy@nvidia.com>
RTCartist added a commit to RTCartist/vllm that referenced this pull request Jun 4, 2026
Replace `get_open_port()` with late binding (port 0) for the remote
XPUB socket in `MessageQueue.__init__`, then read back the actual
bound address via `zmq.LAST_ENDPOINT`. This eliminates the window
between port discovery and socket bind where another process could
claim the port.

Follows the same pattern already used in the DP coordinator
(PR vllm-project#37452).

Closes vllm-project#28498

Signed-off-by: RTCartist <wangshengb@buaa.edu.cn>
mvanhorn pushed a commit to mvanhorn/vllm that referenced this pull request Jun 4, 2026
Signed-off-by: Itay Alroy <ialroy@nvidia.com>
Signed-off-by: Matt Van Horn <455140+mvanhorn@users.noreply.github.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

ready ONLY add when PR is ready to merge/full CI is needed v1

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants